import numpy as np
import torch
import os 
import pickle 
from skimage.transform import resize
from torchvision import transforms
from PIL import Image
import cv2
# Shared image pipeline applied to every sample before it is returned by the
# dataset: convert to a tensor, resize to 224x224, then normalize per channel.
# The mean/std values match the standard ImageNet channel statistics; they are
# 3-element lists, so the input must have exactly three channels.
preprocess = transforms.Compose(
    [
        # Conversion happens first, so the resize below runs on a tensor.
        transforms.ToTensor(),
        transforms.Resize((224,224)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class Dataset_visual(torch.utils.data.Dataset):
    """Map-style dataset pairing stimulus images with per-subject responses.

    Item ``i`` is the tuple ``(X, y)`` where ``X`` is the image
    ``<data_path>/images/<id>.jpg`` run through the module-level
    ``preprocess`` transform and ``y`` is
    ``stimuli_response_dict[id][subject]``, with ``id = self.ids[i]``.

    Args:
        subject: Key used to pick one subject's entry out of each stimulus
            record; also part of the splits file name.
        mode: Key into the loaded splits dict (e.g. ``'train'``) selecting
            which stimulus ids this dataset iterates over.
        brain_region: Prefix of the splits file name
            ``<brain_region>_<subject>_splits.pickle``.
        dim: Stored as ``self.dim``; not otherwise used by this class.
        n_channels: Stored as ``self.n_channels``; not otherwise used by
            this class.
        data_path: Directory holding the pickle files and the ``images``
            sub-directory.

    Raises:
        FileNotFoundError: If either pickle file is missing.
        KeyError: If ``mode`` is not a key of the splits dict.
    """

    def __init__(self, subject=8, mode = 'train', brain_region = 'ventral_visual_data', dim=(224, 224),  n_channels = 3, 
                data_path = 'data/'):
        # SECURITY NOTE: pickle.load can execute arbitrary code from a crafted
        # file, so data_path must only point at trusted data.
        splits_name = brain_region + '_' + str(subject) + '_splits.pickle'
        with open(os.path.join(data_path, splits_name), 'rb') as pickle_file:
            self.splits = pickle.load(pickle_file)
        self.ids = self.splits[mode]
        self.data_path = data_path

        # NOTE(review): the response file is always 'ffa_data.pickle',
        # independent of brain_region -- confirm this is intended.
        with open(os.path.join(data_path, 'ffa_data.pickle'), 'rb') as pickle_file:
            self.stimuli_response_dict = pickle.load(pickle_file)

        # NOTE(review): hard-coded response size; presumably the length of
        # each response vector -- verify against the data file.
        self.resp_sizes = 794
        self.dim = dim
        self.total_size = len(self.ids)
        self.n_channels = n_channels
        self.n_neurons = self.resp_sizes
        self.subject = subject

    def __len__(self):
        """Return the number of stimulus ids in the selected split."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load and preprocess one image and return it with its response.

        Args:
            index: Position within ``self.ids``.

        Returns:
            Tuple ``(X, y)``: the preprocessed image and the stored response
            entry for this dataset's subject.
        """
        data_ix = self.ids[index]
        image_path = os.path.join(self.data_path, 'images', str(data_ix) + '.jpg')

        # Force three channels: `preprocess` normalizes with 3-element
        # mean/std, so grayscale or CMYK JPEGs would otherwise fail there.
        # convert() returns a fully loaded copy, which lets the context
        # manager close the underlying file right away instead of leaking it.
        with Image.open(image_path) as image:
            X = image.convert('RGB')

        y = self.stimuli_response_dict[data_ix][self.subject]
        return preprocess(X), y



    
